import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from utils.utils_loss import logistic_loss
from utils.utils_models import linear_model, mlp_model
    
# Mean-teacher EMA update (typical alpha = 0.97).
def update_ema(model, ema_model, alpha, global_step):
    """Update ``ema_model`` in place as an exponential moving average of ``model``.

    Each parameter of ``ema_model`` is replaced by
    ``a * ema_param + (1 - a) * param`` with
    ``a = min(1 - 1 / (global_step + 1), alpha)``. The warm-up term keeps the
    average close to ``model`` in early steps. At ``global_step == 0``, ``a`` is
    0, so the parameters are simply copied.

    Args:
        model: network whose parameters are averaged into ``ema_model``.
        ema_model: network updated in place (the "teacher").
        alpha: upper bound on the EMA decay factor.
        global_step: number of steps taken so far; expected to be >= 0
            (``global_step == -1`` would divide by zero).

    Note:
        Parameters are paired positionally via ``zip``. Both models are
        assumed to share the same architecture; if one has more parameters,
        the extras are silently skipped. Only ``parameters()`` are averaged;
        buffers are left untouched.
    """
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # Use the keyword form add_(other, alpha=...): the positional
        # add_(Number, Tensor) overload is deprecated in PyTorch.
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
        
def exp_rampup(rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242

    Build a schedule ``f(epoch)`` that rises from ``exp(-5)`` at epoch 0 to
    ``1.0`` at ``epoch == rampup_length``, following
    ``exp(-5 * (1 - epoch / rampup_length) ** 2)``. Epochs below 0 are clamped
    to 0. Epochs at or past ``rampup_length`` return ``1.0``.

    Args:
        rampup_length: number of epochs over which to ramp up. A value <= 0
            disables the ramp-up, so the schedule always returns 1.0. This
            avoids a 0/0 -> nan when ``rampup_length == 0`` and the epoch is
            negative.

    Returns:
        A callable ``f(epoch) -> float``.
    """
    def wrapper(epoch):
        if rampup_length <= 0 or epoch >= rampup_length:
            return 1.0
        clamped = np.clip(epoch, 0.0, rampup_length)
        phase = 1.0 - clamped / rampup_length
        return float(np.exp(-5.0 * phase * phase))
    return wrapper

